---
title: "PAPEA Evaluate Predictions"
author: "Sebastian Haunss, Priska Daphi, Jan Matti Dollbaum, Lidiya Hristova, Pál Susánszky, Elias Steinhilper"
date: "`r Sys.Date()`"
output: 
  html_document:
    theme: cerulean
    toc: yes
    toc_float:
      collapsed: true
---

```{r setup, include=FALSE}
# show the R code of every chunk in the rendered report
knitr::opts_chunk$set(echo = TRUE)
```


This script takes the predictions from model training and evaluation and computes precision, recall, and F1 scores (and, in the confusion-matrix checks, accuracy) at the sentence and article level.

### Load necessary libraries

```{r, message=FALSE}
library(tidyverse)
library(ggplot2)
library(kableExtra)
```


## Evaluate Form Predictions

### Read data

```{r, message=FALSE}
# manually annotated protest forms (the reference the predictions are scored against)
protest_forms_annotated <- read_csv("../data/fgz_papea_forms.csv")

# protest forms predicted by the model
protest_forms_predicted <- read_csv("../data/fgz_papea_form_predictions.csv")
```

### Define function to compute precision, recall and F1 scores

```{r}
# create a custom evaluation function
#
# Per-class and macro-averaged precision, recall and F1.
#
# actual:    vector of annotated (true) classes
# predicted: vector of predicted classes, same length as `actual`
# returns:   data.frame with columns precision, recall, f1; one row per class
#            (row names = class labels, sorted) followed by a "macro" row
#            holding the unweighted means. Overall accuracy is attached as
#            attr(<result>, "accuracy").
# Pairs where either value is NA are not counted (table() drops NA).
# A class that is never predicted gets NaN precision (0/0); a class that is
# never annotated gets NaN recall. The macro means skip NaN (na.rm = TRUE).
f1_eval_forms <- function(actual, predicted) {
  classes <- sort(union(actual, predicted))
  # confusion matrix: rows = actual class, columns = predicted class
  cm <- table(factor(actual, classes), factor(predicted, classes))

  n_correct <- diag(cm)       # correctly classified instances per class
  n_actual <- rowSums(cm)     # annotated instances per class
  n_predicted <- colSums(cm)  # predictions per class

  precision <- n_correct / n_predicted
  recall <- n_correct / n_actual
  f1 <- 2 * precision * recall / (precision + recall)

  df_eval <- rbind(
    data.frame(precision, recall, f1),
    data.frame(
      precision = mean(precision, na.rm = TRUE),
      recall = mean(recall, na.rm = TRUE),
      f1 = mean(f1, na.rm = TRUE),
      row.names = "macro"
    )
  )
  # previously computed but discarded; exposed without changing the table shape
  attr(df_eval, "accuracy") <- sum(n_correct) / sum(cm)
  df_eval
}
```

### Prepare data for evaluation

```{r}
# rearrange manually annotated data (pivot wider):
# one row per article, its distinct annotated forms spread into F1, F2, ...
forms_per_article <- protest_forms_annotated %>%
  select(fid, form) %>%
  unique() %>%
  group_by(fid) %>%
  mutate(cid = paste0("F", row_number())) %>%
  pivot_wider(values_from = form, names_from = cid) %>%
  ungroup() %>%
  arrange(fid)

# count predictions per form per article (only form codes > 0 are kept)
predicted_forms <- protest_forms_predicted %>%
  count(fid, form = pred_form, name = "form_count") %>%
  filter(form > 0)

# join predicted forms and manually annotated form data
forms_comparison <- left_join(predicted_forms, forms_per_article, by = "fid") %>%
  mutate(eval = 0, .before = "F1") %>%
  group_by(fid) %>%
  mutate(n_found_forms = n(), .before = "F1") %>%
  ungroup()

# annotated-form columns, selected by name: how many there are depends on the
# article with the most annotated forms, so fixed column positions would break
annotated_form_cols <- grep("^F[0-9]+$", names(forms_comparison), value = TRUE)

# check whether the predicted form matches any of the actual forms in the
# respective article (eval = 1) or not (eval = 0); articles without annotations
# have only NA in these columns and therefore get eval = 0
forms_comparison <- forms_comparison %>%
  rowwise() %>%
  mutate(eval = as.numeric(form %in% c_across(all_of(annotated_form_cols)))) %>%
  ungroup()

```

### Evaluate at form level

```{r}
# 1) evaluate whether the prediction matches the first annotated form
# If the predicted form matches any annotated form of the article, that form is
# taken as the actual class; otherwise the article's first annotated form (F1).
eval1 <- forms_comparison %>%
  mutate(actual = case_when(eval == 1 ~ form,
                            .default = F1)) %>%
  select(predicted = form, actual)

df_eval1 <- f1_eval_forms(eval1$actual, eval1$predicted)
# n = annotated instances per class, counted over the same sorted class levels
# f1_eval_forms uses, so the column lines up with its rows even when a class
# occurs only among the predictions
form_levels1 <- sort(union(eval1$actual, eval1$predicted))
df_eval1$n <- c(table(factor(eval1$actual, levels = form_levels1)), NA)

df_eval1 %>%
  kbl(col.names = c("category", names(df_eval1)), digits = 2) %>%
  kable_styling(full_width = FALSE)
```

### Evaluate at the level of frequently predicted forms

```{r}
# 2) evaluate whether the more often found forms match with annotated forms
# create dataframe with frequently found forms (>2)
hf_forms <- forms_comparison %>%
  filter(form_count > 2)
hf_fids <- hf_forms$fid %>%
  unique()

# from the rest keep only the most often found form per article
lf_forms <- forms_comparison %>%
  filter(!fid %in% hf_fids) %>%
  group_by(fid) %>%
  arrange(desc(form_count), .by_group = TRUE) %>%
  slice(1) %>%
  ungroup()

# re-combine the two data frames
eval_forms <- bind_rows(hf_forms, lf_forms)

eval2 <- eval_forms %>%
  mutate(actual = case_when(eval == 1 ~ form,
                            .default = F1)) %>%
  select(predicted = form, actual)

df_eval2 <- f1_eval_forms(eval2$actual, eval2$predicted)
# n = annotated instances per class, counted over the same sorted class levels
# f1_eval_forms uses, so the column lines up with its rows
form_levels2 <- sort(union(eval2$actual, eval2$predicted))
df_eval2$n <- c(table(factor(eval2$actual, levels = form_levels2)), NA)

df_eval2 %>%
  kbl(col.names = c("category", names(df_eval2)), digits = 2) %>%
  kable_styling(full_width = FALSE)
```

### Evaluate at the article level
(This is the data used for Table 1b in the article.)

``` {r}
# 3) aggregate at article level:
# keep one predicted form per article, preferring a correctly predicted form
# and, among those, the one predicted most often (eval_sum = eval * form_count,
# i.e. 0 for a wrong form)
predicted_forms_per_article <- forms_comparison %>%
  select(fid, form, form_count, eval, F1) %>%
  mutate(eval_sum = eval * form_count) %>%
  group_by(fid) %>%
  arrange(desc(eval_sum), .by_group = TRUE) %>%
  slice(1) %>%
  ungroup()

eval3 <- predicted_forms_per_article %>%
  mutate(actual = case_when(eval == 1 ~ form,
                            .default = F1)) %>%
  select(predicted = form, actual)

df_eval3 <- f1_eval_forms(eval3$actual, eval3$predicted)
# n = annotated instances per class, counted over the same sorted class levels
# f1_eval_forms uses, so the column lines up with its rows
form_levels3 <- sort(union(eval3$actual, eval3$predicted))
df_eval3$n <- c(table(factor(eval3$actual, levels = form_levels3)), NA)

df_eval3 %>%
  kbl(col.names = c("category", names(df_eval3)),
      digits = 2,
      caption = "Table 1b, right panel") %>%
  kable_styling(full_width = FALSE)
```

### Additional checks

#### Check confusion matrix
``` {r}
# class labels seen in either the annotations or the predictions
u3 <- sort(union(eval3$actual, eval3$predicted))
# turn a vector of form codes into a factor with exactly those labels as levels
as_form_class <- function(x) factor(x, levels = u3)

# raw confusion matrix: rows = annotated class, columns = predicted class
cm_article_level <- as.matrix(table(as_form_class(eval3$actual),
                                    as_form_class(eval3$predicted)))

# caret's confusion-matrix summary of the same comparison
cf <- caret::confusionMatrix(data = as_form_class(eval3$predicted),
                             reference = as_form_class(eval3$actual))
print(cf)
```

#### Correct predictions per article

``` {r}
# correct predictions: number of articles per value of eval_sum
# (0 = the selected form was not among the annotated forms)
all_classifications <- data.frame(table(predicted_forms_per_article$eval_sum))

all_classifications %>%
  mutate(fill.var = ifelse(Var1 == 0, 1, 0)) %>%  # highlight the "0 correct" bar
  ggplot(aes(x = Var1, y = Freq, fill = fill.var)) +
  geom_col(show.legend = FALSE) +
  labs(x = "Number of Correct Predictions per Article",
       y = "Number of Cases") +
  theme_bw()
```

#### Incorrect predictions


``` {r}
# incorrect predictions: articles whose selected form was not among the
# annotated forms (eval_sum == 0), counted per predicted form
wrong_classifications <- with(predicted_forms_per_article,
                              data.frame(table(form[eval_sum == 0])))

ggplot(wrong_classifications, aes(x = factor(Var1), y = Freq)) +
  geom_col() +
  labs(x = "Categories", y = "Number of Cases") +
  theme_bw()
```

---

## Evaluate Claim Predictions

### Read data

```{r, message=FALSE}
# manually annotated protest claims (the reference the predictions are scored against)
protest_claims_annotated <- read_csv("../data/fgz_papea_claims.csv")

# protest claims predicted by the model
protest_claims_predicted <- read_csv("../data/fgz_papea_claim_predictions.csv")

```

### Define function to compute precision, recall and F1 scores

```{r}
# create a custom evaluation function that accounts for unpredicted claims
#
# Claim codes that were never predicted. They are removed from the class set
# before evaluating, so every pair whose actual OR predicted value is one of
# these codes drops out of the confusion matrix (factor() maps it to NA and
# table() ignores NA).
unpredicted_claims <- c(107, 1540, 9900)

# Per-class and macro-averaged precision, recall and F1 for claim predictions.
#
# actual:    vector of annotated (true) claims
# predicted: vector of predicted claims, same length as `actual`
# exclude:   claim codes left out of the evaluation (default: unpredicted_claims)
# returns:   data.frame with columns precision, recall, f1; one row per retained
#            class (row names = class labels, sorted) followed by a "macro" row
#            with the unweighted means. Accuracy over the retained pairs is
#            attached as attr(<result>, "accuracy").
# A class that is never predicted gets NaN precision (0/0); the macro means
# skip NaN (na.rm = TRUE).
f1_eval_claims <- function(actual, predicted, exclude = unpredicted_claims) {
  classes <- setdiff(sort(union(actual, predicted)), exclude)
  # confusion matrix: rows = actual class, columns = predicted class
  cm <- table(factor(actual, classes), factor(predicted, classes))

  n_correct <- diag(cm)       # correctly classified instances per class
  n_actual <- rowSums(cm)     # annotated instances per class
  n_predicted <- colSums(cm)  # predictions per class

  precision <- n_correct / n_predicted
  recall <- n_correct / n_actual
  f1 <- 2 * precision * recall / (precision + recall)

  df_eval <- rbind(
    data.frame(precision, recall, f1),
    data.frame(
      precision = mean(precision, na.rm = TRUE),
      recall = mean(recall, na.rm = TRUE),
      f1 = mean(f1, na.rm = TRUE),
      row.names = "macro"
    )
  )
  attr(df_eval, "accuracy") <- sum(n_correct) / sum(cm)
  df_eval
}

```

### Prepare data for evaluation

```{r}
# one row per annotated article, its distinct annotated claims spread into C1, C2, ...
claims_per_article <- protest_claims_annotated %>%
  select(fid, claim) %>%
  unique() %>%
  group_by(fid) %>%
  mutate(cid = paste0("C", row_number())) %>%
  pivot_wider(values_from = claim, names_from = cid) %>%
  ungroup() %>%
  arrange(fid)

# one row per predicted claim per article, with how often it was predicted
predicted_claims <- protest_claims_predicted %>%
  count(fid, claim = prediction_claim, name = "claim_count")


# join predicted claims and actual claim data
claim_comparison <- left_join(predicted_claims, claims_per_article, by = "fid") %>%
  mutate(eval = 0, .before = "C1") %>%
  group_by(fid) %>%
  mutate(n_found_claims = n(), .before = "C1") %>%
  ungroup()

# annotated-claim columns, selected by name: how many there are depends on the
# article with the most annotated claims, so fixed column positions would break
annotated_claim_cols <- grep("^C[0-9]+$", names(claim_comparison), value = TRUE)

# check whether the predicted claim matches any of the actual claims in the
# respective article (eval = 1) or not (eval = 0); articles without annotations
# have only NA in these columns and therefore get eval = 0
claim_comparison <- claim_comparison %>%
  rowwise() %>%
  mutate(eval = as.numeric(claim %in% c_across(all_of(annotated_claim_cols)))) %>%
  ungroup()

```

### Evaluate at claim level

```{r}
# If the predicted claim matches any annotated claim of the article, that claim
# is taken as the actual class; otherwise the article's first annotated claim (C1).
eval1 <- claim_comparison %>%
  mutate(actual = case_when(eval == 1 ~ claim,
                            .default = C1)) %>%
  select(predicted = claim, actual)


df_eval1 <- f1_eval_claims(eval1$actual, eval1$predicted)
# n = number of predictions per class (the forms section counts annotated
# instances instead -- NOTE(review): confirm which is intended). Counted over
# the class levels f1_eval_claims keeps (unpredicted_claims removed) so the
# column always lines up with its rows, as in the other claim evaluations.
claim_levels1 <- setdiff(sort(union(eval1$actual, eval1$predicted)), unpredicted_claims)
df_eval1$n <- c(table(factor(eval1$predicted, levels = claim_levels1)), NA)

df_eval1 %>%
  kbl(col.names = c("category", names(df_eval1)), digits = 2) %>%
  kable_styling(full_width = FALSE)
```

### Evaluate at the level of frequently predicted claims

```{r}
# create dataframe with frequently found claims (>2)
hf_claims <- claim_comparison %>%
  filter(claim_count > 2)

hf_fids <- hf_claims$fid %>%
  unique()

# from the rest keep only the most often found claim per article
lf_claims <- claim_comparison %>%
  filter(!fid %in% hf_fids) %>%
  group_by(fid) %>%
  arrange(desc(claim_count), .by_group = TRUE) %>%
  slice(1) %>%
  ungroup()

# re-combine the two data frames
eval_claims <- bind_rows(hf_claims, lf_claims)

eval2 <- eval_claims %>%
  mutate(actual = case_when(eval == 1 ~ claim,
                            .default = C1)) %>%
  select(predicted = claim, actual)

df_eval2 <- f1_eval_claims(eval2$actual, eval2$predicted)
# n = number of predictions per class, counted over the class levels
# f1_eval_claims keeps (unpredicted_claims removed) so it lines up with its rows
claim_levels2 <- setdiff(sort(union(eval2$actual, eval2$predicted)), unpredicted_claims)
df_eval2$n <- c(table(factor(eval2$predicted, levels = claim_levels2)), NA)


df_eval2 %>%
  kbl(col.names = c("category", names(df_eval2)), digits = 2) %>%
  kable_styling(full_width = FALSE)
```

### Evaluate at the article level
(This is the data used for Table 2 in the article.)

``` {r}
# aggregate at article level:
# keep one predicted claim per article, preferring a correctly predicted claim
# and, among those, the one predicted most often (eval_sum = eval * claim_count,
# i.e. 0 for a wrong claim)
predicted_claims_per_article <- claim_comparison %>%
  select(fid, claim, claim_count, eval, C1) %>%
  mutate(eval_sum = eval * claim_count) %>%
  group_by(fid) %>%
  arrange(desc(eval_sum), .by_group = TRUE) %>%
  slice(1) %>%
  ungroup()

eval3 <- predicted_claims_per_article %>%
  mutate(actual = case_when(eval == 1 ~ claim,
                            .default = C1)) %>%
  select(predicted = claim, actual)

df_eval3 <- f1_eval_claims(eval3$actual, eval3$predicted)
# n = number of predictions per class, counted over the class levels
# f1_eval_claims keeps (unpredicted_claims removed) so it lines up with its rows
claim_levels3 <- setdiff(sort(union(eval3$actual, eval3$predicted)), unpredicted_claims)
df_eval3$n <- c(table(factor(eval3$predicted, levels = claim_levels3)), NA)

df_eval3 %>%
  kbl(col.names = c("category", names(df_eval3)),
      digits = 2,
      caption = "Table 2, right panel") %>%
  kable_styling(full_width = FALSE)
```

### Additional checks

#### Check confusion matrix
``` {r}
# class labels seen in the annotations or predictions, minus the excluded claims
u3 <- setdiff(sort(union(eval3$actual, eval3$predicted)), unpredicted_claims)
# turn a vector of claim codes into a factor with exactly those labels as levels
# (excluded codes become NA)
as_claim_class <- function(x) factor(x, levels = u3)

# raw confusion matrix: rows = annotated class, columns = predicted class
cm_article_level <- as.matrix(table(as_claim_class(eval3$actual),
                                    as_claim_class(eval3$predicted)))

# caret's confusion-matrix summary of the same comparison
cf <- caret::confusionMatrix(data = as_claim_class(eval3$predicted),
                             reference = as_claim_class(eval3$actual))
print(cf)
```

#### Correct predictions per article

``` {r}
# correct predictions: number of articles per value of eval_sum
# (0 = the selected claim was not among the annotated claims)
all_classifications <- data.frame(table(predicted_claims_per_article$eval_sum))

all_classifications %>%
  mutate(fill.var = ifelse(Var1 == 0, 1, 0)) %>%  # highlight the "0 correct" bar
  ggplot(aes(x = Var1, y = Freq, fill = fill.var)) +
  geom_col(show.legend = FALSE) +
  labs(x = "Number of Correct Predictions per Article",
       y = "Number of Cases") +
  theme_bw()
```

#### Incorrect predictions


``` {r}
# incorrect predictions: articles whose selected claim was not among the
# annotated claims (eval_sum == 0), counted per predicted claim
wrong_classifications <- with(predicted_claims_per_article,
                              data.frame(table(claim[eval_sum == 0])))

ggplot(wrong_classifications, aes(x = factor(Var1), y = Freq)) +
  geom_col() +
  labs(x = "Categories", y = "Number of Cases") +
  theme_bw()
```

